import torch
import torch.nn as nn
from easycore.common.registry import Registry

# Registry of optimizer builder callables. Builders in this module are added
# with the ``@OPTIMIZER_REGISTRY.register()`` decorator and are looked up by
# ``build_optimizer`` using ``cfg.OPTIMIZER.NAME``.
OPTIMIZER_REGISTRY = Registry('optimizer')

def build_optimizer(cfg, model):
    """Build the optimizer selected by the configuration.

    The builder registered in ``OPTIMIZER_REGISTRY`` under
    ``cfg.OPTIMIZER.NAME`` is fetched and invoked with ``(cfg, model)``.

    Args:
        cfg: Configuration object exposing ``cfg.OPTIMIZER.NAME`` (plus
            whatever fields the selected builder reads).
        model: Model forwarded unchanged to the selected builder.

    Returns:
        Whatever the selected builder returns.
    """
    # NOTE(review): behavior for an unregistered name is decided by
    # ``Registry.get`` -- confirm against the easycore documentation.
    builder = OPTIMIZER_REGISTRY.get(cfg.OPTIMIZER.NAME)
    return builder(cfg, model)

@OPTIMIZER_REGISTRY.register()
def build_AdamW_optimizer(cfg, model):
    """Create a ``torch.optim.AdamW`` optimizer over all parameters of ``model``.

    Args:
        cfg: Configuration object; ``cfg.OPTIMIZER.LR`` is used as the
            learning rate.
        model: Object whose ``parameters()`` are handed to the optimizer.

    Returns:
        The constructed ``torch.optim.AdamW`` instance. All other
        hyper-parameters are left at the PyTorch defaults.
    """
    learning_rate = cfg.OPTIMIZER.LR
    return torch.optim.AdamW(model.parameters(), lr=learning_rate)

@OPTIMIZER_REGISTRY.register()
def build_SGD_optimizer(cfg, model):
    """Create a ``torch.optim.SGD`` optimizer with Nesterov momentum.

    Args:
        cfg: Configuration object; ``cfg.OPTIMIZER.LR`` is used as the
            learning rate and ``cfg.OPTIMIZER.MOMENTUM`` as the momentum
            factor.
        model: Object whose ``parameters()`` are handed to the optimizer.

    Returns:
        The constructed ``torch.optim.SGD`` instance.

    Note:
        Nesterov momentum is always enabled. PyTorch rejects
        ``nesterov=True`` with a ``ValueError`` unless the momentum is
        positive, so ``cfg.OPTIMIZER.MOMENTUM`` must be greater than zero.
    """
    optimizer = torch.optim.SGD(
        model.parameters(),
        cfg.OPTIMIZER.LR,
        momentum=cfg.OPTIMIZER.MOMENTUM,
        nesterov=True,
    )
    # Return the optimizer so ``build_optimizer`` hands it back to the caller;
    # without this the builder would implicitly return ``None``.
    return optimizer